{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plotting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Nearly all of the `hypertools` functionality may be accessed through the main `plot` function.  This design enables complex data analysis, data manipulation, and plotting to be carried out in a single function call.  To use it, simply pass your samples by features dataset(s) [in the form of a numpy array, pandas dataframe, or (mixed) list] to the `plot` function. Let's explore!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Hypertools and other libraries for tutorial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import hypertools as hyp\n",
    "import numpy as np\n",
    "import scipy\n",
    "import pandas as pd\n",
    "from scipy.linalg import toeplitz\n",
    "from copy import copy\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load your data\n",
    "\n",
    "We will load one of the sample datasets. This dataset consists of 8,124 samples of mushrooms with various text features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "geo = hyp.load('mushrooms')\n",
    "mushrooms = geo.get_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can peek at the first few rows of the dataframe using the pandas function `head()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mushrooms.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot with default settings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Hypertools` can plot your high-dimensional data quickly and easily with little user-generated input. By default, `hypertools` automatically reduces your data via incremental principal component analysis (if dimensions > 3) and plots plots a 3D line plot where the axes represent the top 3 principal components of the dataset. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geo = hyp.plot(mushrooms) # plots a line"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, hypertools assumes you are passing in a timeseries, and so it plots a trajectory of the data evolving over time.  If you  aren't visualizing a timeseries, you can instead plot the individual observations as dots or other symbols by specifying an appropriate format style.\n",
    "\n",
    "To show the individual points, simply pass the `'.'` format string in the second argument position, or in any position using `fmt='.'`; the format string is parsed by [matplotlib](http://matplotlib.org/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geo = hyp.plot(mushrooms, '.') # plots dots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geo = hyp.plot(mushrooms, fmt = 'b*') # plots blue asterisks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot in 2D "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also opt to plot high dimensional data in two dimensional space, rather than 3D, by passing the `ndims` argument."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geo = hyp.plot(mushrooms, '.', ndims=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using other dimensionality reduction algorithms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To explore a data reduction method aside from the default (PCA), use `reduce` argument. Here, we pass the reduce argument a string.\n",
    "\n",
    "Other supported reduction models include: PCA, IncrementalPCA, SparsePCA, MiniBatchSparsePCA, KernelPCA, FastICA, FactorAnalysis, TruncatedSVD, DictionaryLearning, MiniBatchDictionaryLearning, TSNE, Isomap, SpectralEmbedding, LocallyLinearEmbedding, MDS, UMAP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geo = hyp.plot(mushrooms, '.', reduce='SparsePCA')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parameter Specifications"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For finer control of the parameters, you can pass the reduce argument a dictionary (see scikit learn documentation about parameter options for specific reduction techniques)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geo = hyp.plot(mushrooms, '.', reduce={'model' : 'PCA', 'params' : {'whiten' : True}})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Coloring by hue"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To color your datapoints by group labels, pass the `hue` argument. It accepts strings, ints, and floats, or a list of these. You must pass hue the same number of labels as you have rows in your data matrix.\n",
    "\n",
    "Here, we group the data in five different chunks of equal size (size #points / 5) for simplicity. Note that we pass ints, strings, floats, and None in the same list to the hue argument."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "split = int(mushrooms.shape[0]/ 5)\n",
    "hue = [1]*split + ['two']*split + [3.0]*split + [None]*split + ['four']*split\n",
    "geo_group = hyp.plot(mushrooms, '.', hue=hue)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Adding a legend"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When coloring, you may want a legend to indicate group type. Passing `legend=True` will generate the legend based on your groupings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "split = int(mushrooms.shape[0]/5)\n",
    "hue = [1]*split + ['two']*split + [3.0]*split + [None]*split + ['four']*split\n",
    "geo_hue = hyp.plot(mushrooms, '.', hue=hue, legend=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interpolating missing data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Missing data points? No problem! `Hypertools` will fill missing values via probabalistic principal components analysis (PPCA). Here, we generate a small synthetic dataset, remove a few values, then use PPCA to infer those missing values. Then, we plot the original data and the interpolated data, for comparison.  The one exception is that in cases where the entire data sample (row) is nans.  In this scenario, there is no data for PPCA to base its guess on, so the inference will fail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# simulate data\n",
    "K = 10 - toeplitz(np.arange(10))\n",
    "data1 = np.cumsum(np.random.multivariate_normal(np.zeros(10), K, 250), axis=0)\n",
    "data2 = copy(data1)\n",
    "\n",
    "# randomly remove 5% of the data\n",
    "missing = .01\n",
    "inds = [(i,j) for i in range(data1.shape[0]) for j in range(data1.shape[1])]\n",
    "missing_data = [inds[i] for i in np.random.choice(int(len(inds)), int(len(inds)*missing))]\n",
    "for i,j in missing_data:\n",
    "    data2[i,j]=np.nan\n",
    "\n",
    "# reduce the data\n",
    "data1_r,data2_r = hyp.reduce([data1, data2], ndims=3)\n",
    "\n",
    "# pull out missing inds\n",
    "missing_inds = hyp.tools.missing_inds(data2)\n",
    "missing_data = data2_r[missing_inds, :]\n",
    "\n",
    "# plot\n",
    "geon_nan = hyp.plot([data1_r, data2_r, missing_data], ['-', '--', '*'],\n",
    "         legend=['Full', 'Missing', 'Missing Points'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Labeling plotted points"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `labels` argument accepts a list of labels for each point, which must be the same length as the data (the number of rows). If no label is wanted for a particular point, simply input `None`. In this example, we have made use of `None` in order to label only three points of interest (the first, second, and last in our set). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_unlabeled = int(mushrooms.shape[0])-3\n",
    "labeling = ['a','b'] + [None]*num_unlabeled + ['c']\n",
    "label = hyp.plot(mushrooms, '.', labels = labeling)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Clustering"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hypertools can also auto-cluster your datapoints with the `n_clusters` argument. To implement, simply set `n_clusters` to an integer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geo_cluster = hyp.plot(mushrooms, '.', n_clusters = 6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Normalization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For quick, easy data normalization of the input data, pass the normalize argument.\n",
    "\n",
    "You can pass the following arguments as strings: \n",
    "+ across - columns z-scored across lists (default)\n",
    "+ within - columns z-scored within each list\n",
    "+ row - each row z-scored "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "geo_cluster = hyp.plot(mushrooms, '.', normalize = 'within')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aligning datasets with different coordinate systems"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also align multiple datasets using the hypertools plot function in order to visualize data in a common space. This is useful, if you have more than one high-dimensional dataset that is related to the same thing.  For example, consider a brain imaging (fMRI) dataset comprised of multiple subjects watching the same movie. Voxel A in subject 1 may not necessarily be Voxel A in subject 2.  Alignment allows you to rotate and scale multiple datasets so they are in maximal alignment with one another.\n",
    "\n",
    "To do so, pass one of the following strings to the align argument (as shown below):\n",
    "\n",
    "+ `hyper` - hyperalignment algorithm (default) See: http://haxbylab.dartmouth.edu/publications/HGC+11.pdf\n",
    "+ `SRM` - shared response model algorithm. See: https://papers.nips.cc/paper/5855-a-reduced-dimension-fmri-shared-response-model.pdf\n",
    "\n",
    "Below, is a simple example of a spiral."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load example data\n",
    "geo = hyp.load('spiral')\n",
    "geo.plot(title='Before Alignment')\n",
    "\n",
    "# use procrusted to align the data\n",
    "source, target = geo.get_data()\n",
    "aligned = [hyp.tools.procrustes(source, target), target]\n",
    "\n",
    "# after alignment\n",
    "geo_aligned = hyp.plot(aligned, ['-','--'], title='After alignment')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To save a plot created with hypertools, simply pass the `save_path` argument."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# geo_cluster = hyp.plot(mushrooms, '.', save_path='cluster_plot.pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting text using semantic models\n",
    "\n",
    "In addition to numerical data, `hypertools` supports the plotting of text data by fitting the data to a semantic model.  We'll load in an example text dataset to get started which is comprised of all State of the Union Addresses from 1989-2017."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": "# Create sample State of the Union text data for demonstration\nsample_sotus_speeches = [\n    \"Tonight I can report to the nation that America is stronger and more secure.\",\n    \"We gather tonight knowing that this generation has been tested by crisis.\",\n    \"As we work together to advance America's interests, we must recognize threats.\",\n    \"Education is the great equalizer in America for all our children.\"\n]"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, the text data will be transformed using a Latent Dirichlet Model trained on a sample of wikipedia pages. Simply pass the list of text data to the `plot` function, and under the hood it will be transformed to a topic vector and then reduced for plotting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": "hyp.plot(sample_sotus_speeches)"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}